{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Execute main.py inside the notebook namespace. The constants and helpers used\n",
    "# below (DATA_DIR, S3, ...) are presumably defined there -- TODO confirm.\n",
    "%run main.py\n",
    "# Reload imported modules automatically before each cell is executed.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# Create the local working directories (-p: no error if they already exist).\n",
    "!mkdir -p {DATA_DIR} {RUBERT_DIR} {MODEL_DIR}\n",
    "s3 = S3()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch each split independently. Guarding both downloads on TEST alone would\n",
    "# leave TRAIN missing forever if an earlier run stopped between the two downloads.\n",
    "for remote_path, local_path in [(S3_TEST, TEST), (S3_TRAIN, TRAIN)]:\n",
    "    if not exists(local_path):\n",
    "        s3.download(remote_path, local_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch every RuBERT artifact that is missing locally. Each file is checked on\n",
    "# its own so that a partially completed download is finished on the next run\n",
    "# instead of being skipped because the vocab file already exists.\n",
    "rubert_files = [\n",
    "    (S3_RUBERT_VOCAB, RUBERT_VOCAB),\n",
    "    (S3_RUBERT_EMB, RUBERT_EMB),\n",
    "    (S3_RUBERT_ENCODER, RUBERT_ENCODER),\n",
    "    (S3_RUBERT_MLM, RUBERT_MLM),\n",
    "]\n",
    "for remote_path, local_path in rubert_files:\n",
    "    if not exists(local_path):\n",
    "        s3.download(remote_path, local_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the vocabulary from the RuBERT vocab file fetched in the previous cell.\n",
    "vocab = BERTVocab.load(RUBERT_VOCAB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the masked-LM model from an embedding, an encoder and an MLM head.\n",
    "# All three parts are sized from the same RuBERT config object.\n",
    "config = RuBERTConfig()\n",
    "emb = BERTEmbedding.from_config(config)\n",
    "encoder = BERTEncoder.from_config(config)\n",
    "head = BERTMLMHead(config.emb_dim, config.vocab_size)\n",
    "model = BERTMLM(emb, encoder, head)\n",
    "\n",
    "# Freeze the position embeddings: training runs on short sequences only.\n",
    "emb.position.weight.requires_grad = False\n",
    "\n",
    "# Load the RuBERT weights downloaded earlier, then move the model to DEVICE.\n",
    "model.emb.load(RUBERT_EMB)\n",
    "model.encoder.load(RUBERT_ENCODER)\n",
    "model.head.load(RUBERT_MLM)\n",
    "model = model.to(DEVICE)\n",
    "\n",
    "criterion = masked_flatten_cross_entropy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the random number generators before any batches are built.\n",
    "# `seed` is presumably random.seed brought in by main.py -- TODO confirm.\n",
    "torch.manual_seed(1)\n",
    "seed(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "encode = BERTMLMTrainEncoder(vocab, seq_len=128, batch_size=32, shuffle_size=10000)\n",
    "\n",
    "# The test split is materialized as a list because it is re-used at every evaluation.\n",
    "test_batches = [batch.to(DEVICE) for batch in encode(load_lines(TEST))]\n",
    "\n",
    "# The train split stays a generator: batches are moved to DEVICE only as the\n",
    "# training loop pulls them, and the stream can be consumed just once.\n",
    "train_batches = (batch.to(DEVICE) for batch in encode(load_lines(TRAIN)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One board for the run, with a separate section for train and for test scores.\n",
    "board = TensorBoard(BOARD_NAME, RUNS_DIR)\n",
    "train_board, test_board = map(board.section, (TRAIN_BOARD, TEST_BOARD))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = optim.Adam(model.parameters(), lr=0.0001)\n",
    "# Apex expects opt_level as a string ('O0'..'O3'); amp.initialize is called once\n",
    "# the model and optimizer exist and returns the patched pair.\n",
    "model, optimizer = amp.initialize(model, optimizer, opt_level='O2')\n",
    "# The scheduler wraps the optimizer returned by amp; the learning rate is\n",
    "# multiplied by gamma on every scheduler.step().\n",
    "scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_meter = MLMScoreMeter()\n",
    "test_meter = MLMScoreMeter()\n",
    "\n",
    "# Gradients of accum_steps mini-batches are accumulated before each optimizer\n",
    "# step: 64 * 32 (batch_size of the train encoder) = 2048 items per update.\n",
    "accum_steps = 64  # 2K batch\n",
    "log_steps = 256\n",
    "eval_steps = 512\n",
    "save_steps = eval_steps * 10\n",
    "\n",
    "model.train()\n",
    "optimizer.zero_grad()\n",
    "\n",
    "for step, batch in log_progress(enumerate(train_batches)):\n",
    "    batch = process_batch(model, criterion, batch)\n",
    "    # Divide the loss so the accumulated gradient is an average over\n",
    "    # accum_steps mini-batches instead of a sum.\n",
    "    batch.loss /= accum_steps\n",
    "\n",
    "    with amp.scale_loss(batch.loss, optimizer) as scaled:\n",
    "        scaled.backward()\n",
    "\n",
    "    # NOTE(review): batch.loss has already been divided by accum_steps here; if\n",
    "    # the score includes the loss it is logged at 1/accum_steps scale -- confirm.\n",
    "    score = score_mlm_batch(batch, ks=())\n",
    "    train_meter.add(score)\n",
    "\n",
    "    if every(step, log_steps):\n",
    "        train_meter.write(train_board)\n",
    "        train_meter.reset()\n",
    "\n",
    "    if every(step, accum_steps):\n",
    "        optimizer.step()\n",
    "        scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # eval_steps is a multiple of accum_steps, so nesting the check here\n",
    "        # presumably runs evaluation right after an optimizer update -- this\n",
    "        # depends on how every() treats step 0, TODO confirm.\n",
    "        if every(step, eval_steps):\n",
    "            batches = infer_batches(model, criterion, test_batches)\n",
    "            scores = score_mlm_batches(batches)\n",
    "            test_meter.extend(scores)\n",
    "            test_meter.write(test_board)\n",
    "            test_meter.reset()\n",
    "\n",
    "    if every(step, save_steps):\n",
    "        # Checkpoint the three sub-modules locally, then mirror them to S3.\n",
    "        # The head is reached through model.head, the same attribute the\n",
    "        # RuBERT MLM weights were loaded through.\n",
    "        model.emb.dump(MODEL_EMB)\n",
    "        model.encoder.dump(MODEL_ENCODER)\n",
    "        model.head.dump(MODEL_MLM)\n",
    "\n",
    "        s3.upload(MODEL_EMB, S3_MODEL_EMB)\n",
    "        s3.upload(MODEL_ENCODER, S3_MODEL_ENCODER)\n",
    "        s3.upload(MODEL_MLM, S3_MODEL_MLM)\n",
    "\n",
    "    # Called once per mini-batch, after any logging for this step.\n",
    "    board.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
